-
Notifications
You must be signed in to change notification settings - Fork 14.8k
[NFC][PGO] Factor downscaling of branch weights out of Instrumentation
into ProfileData
#153735
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
This stack of pull requests is managed by Graphite. Learn more about stacking. |
@llvm/pr-subscribers-pgo @llvm/pr-subscribers-llvm-ir Author: Mircea Trofin (mtrofin) ChangesThe logic isn’t instrumentation-specific, and the refactoring allows users avoid a dependency on InstrumentationFull diff: https://github.com/llvm/llvm-project/pull/153735.diff 4 Files Affected:
diff --git a/llvm/include/llvm/IR/ProfDataUtils.h b/llvm/include/llvm/IR/ProfDataUtils.h
index ca56e4aa81575..16c6e8ed1af9c 100644
--- a/llvm/include/llvm/IR/ProfDataUtils.h
+++ b/llvm/include/llvm/IR/ProfDataUtils.h
@@ -144,6 +144,33 @@ LLVM_ABI bool extractProfTotalWeight(const Instruction &I,
LLVM_ABI void setBranchWeights(Instruction &I, ArrayRef<uint32_t> Weights,
bool IsExpected);
+/// downscale the given weights preserving the ratio. If the maximum value is
+/// not already known and not provided via \param KnownMaxCount , it will be
+/// obtained from \param Weights.
+LLVM_ABI SmallVector<uint32_t>
+downscaleWeights(ArrayRef<uint64_t> Weights,
+ std::optional<uint64_t> KnownMaxCount = std::nullopt);
+
+/// Calculate what to divide by to scale counts.
+///
+/// Given the maximum count, calculate a divisor that will scale all the
+/// weights to strictly less than std::numeric_limits<uint32_t>::max().
+LLVM_ABI inline uint64_t calculateCountScale(uint64_t MaxCount) {
+ return MaxCount < std::numeric_limits<uint32_t>::max()
+ ? 1
+ : MaxCount / std::numeric_limits<uint32_t>::max() + 1;
+}
+
+/// Scale an individual branch count.
+///
+/// Scale a 64-bit weight down to 32-bits using \c Scale.
+///
+LLVM_ABI inline uint32_t scaleBranchCount(uint64_t Count, uint64_t Scale) {
+ uint64_t Scaled = Count / Scale;
+ assert(Scaled <= std::numeric_limits<uint32_t>::max() && "overflow 32-bits");
+ return Scaled;
+}
+
/// Specify that the branch weights for this terminator cannot be known at
/// compile time. This should only be called by passes, and never as a default
/// behavior in e.g. MDBuilder. The goal is to use this info to validate passes
diff --git a/llvm/include/llvm/Transforms/Utils/Instrumentation.h b/llvm/include/llvm/Transforms/Utils/Instrumentation.h
index 962d9e734a40a..93ab8c693607f 100644
--- a/llvm/include/llvm/Transforms/Utils/Instrumentation.h
+++ b/llvm/include/llvm/Transforms/Utils/Instrumentation.h
@@ -169,26 +169,6 @@ struct SanitizerCoverageOptions {
SanitizerCoverageOptions() = default;
};
-/// Calculate what to divide by to scale counts.
-///
-/// Given the maximum count, calculate a divisor that will scale all the
-/// weights to strictly less than std::numeric_limits<uint32_t>::max().
-static inline uint64_t calculateCountScale(uint64_t MaxCount) {
- return MaxCount < std::numeric_limits<uint32_t>::max()
- ? 1
- : MaxCount / std::numeric_limits<uint32_t>::max() + 1;
-}
-
-/// Scale an individual branch count.
-///
-/// Scale a 64-bit weight down to 32-bits using \c Scale.
-///
-static inline uint32_t scaleBranchCount(uint64_t Count, uint64_t Scale) {
- uint64_t Scaled = Count / Scale;
- assert(Scaled <= std::numeric_limits<uint32_t>::max() && "overflow 32-bits");
- return Scaled;
-}
-
// Use to ensure the inserted instrumentation has a DebugLocation; if none is
// attached to the source instruction, try to use a DILocation with offset 0
// scoped to surrounding function (if it has a DebugLocation).
diff --git a/llvm/lib/IR/ProfDataUtils.cpp b/llvm/lib/IR/ProfDataUtils.cpp
index b1b5f67689e6d..489fbfef00e4d 100644
--- a/llvm/lib/IR/ProfDataUtils.cpp
+++ b/llvm/lib/IR/ProfDataUtils.cpp
@@ -270,6 +270,18 @@ void setBranchWeights(Instruction &I, ArrayRef<uint32_t> Weights,
I.setMetadata(LLVMContext::MD_prof, BranchWeights);
}
+SmallVector<uint32_t> downscaleWeights(ArrayRef<uint64_t> Weights,
+ std::optional<uint64_t> KnownMaxCount) {
+ uint64_t MaxCount = KnownMaxCount.has_value() ? KnownMaxCount.value()
+ : *llvm::max_element(Weights);
+ assert(MaxCount > 0 && "Bad max count");
+ uint64_t Scale = calculateCountScale(MaxCount);
+ SmallVector<unsigned, 4> DownscaledWeights;
+ for (const auto &ECI : Weights)
+ DownscaledWeights.push_back(scaleBranchCount(ECI, Scale));
+ return DownscaledWeights;
+}
+
void scaleProfData(Instruction &I, uint64_t S, uint64_t T) {
assert(T != 0 && "Caller should guarantee");
auto *ProfileData = I.getMetadata(LLVMContext::MD_prof);
diff --git a/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp b/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp
index e0b22ef94d064..374f64069f4e0 100644
--- a/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp
+++ b/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp
@@ -2409,11 +2409,7 @@ static std::string getSimpleNodeName(const BasicBlock *Node) {
void llvm::setProfMetadata(Instruction *TI, ArrayRef<uint64_t> EdgeCounts,
uint64_t MaxCount) {
- assert(MaxCount > 0 && "Bad max count");
- uint64_t Scale = calculateCountScale(MaxCount);
- SmallVector<unsigned, 4> Weights;
- for (const auto &ECI : EdgeCounts)
- Weights.push_back(scaleBranchCount(ECI, Scale));
+ auto Weights = downscaleWeights(EdgeCounts, MaxCount);
LLVM_DEBUG(dbgs() << "Weight is: "; for (const auto &W
: Weights) {
@@ -2434,7 +2430,7 @@ void llvm::setProfMetadata(Instruction *TI, ArrayRef<uint64_t> EdgeCounts,
uint64_t TotalCount =
std::accumulate(EdgeCounts.begin(), EdgeCounts.end(), (uint64_t)0,
[](uint64_t c1, uint64_t c2) { return c1 + c2; });
- Scale = calculateCountScale(WSum);
+ auto Scale = calculateCountScale(WSum);
BranchProbability BP(scaleBranchCount(Weights[0], Scale),
scaleBranchCount(WSum, Scale));
std::string BranchProbStr;
|
@llvm/pr-subscribers-llvm-transforms Author: Mircea Trofin (mtrofin) ChangesThe logic isn’t instrumentation-specific, and the refactoring allows users avoid a dependency on InstrumentationFull diff: https://github.com/llvm/llvm-project/pull/153735.diff 4 Files Affected:
diff --git a/llvm/include/llvm/IR/ProfDataUtils.h b/llvm/include/llvm/IR/ProfDataUtils.h
index ca56e4aa81575..16c6e8ed1af9c 100644
--- a/llvm/include/llvm/IR/ProfDataUtils.h
+++ b/llvm/include/llvm/IR/ProfDataUtils.h
@@ -144,6 +144,33 @@ LLVM_ABI bool extractProfTotalWeight(const Instruction &I,
LLVM_ABI void setBranchWeights(Instruction &I, ArrayRef<uint32_t> Weights,
bool IsExpected);
+/// downscale the given weights preserving the ratio. If the maximum value is
+/// not already known and not provided via \param KnownMaxCount , it will be
+/// obtained from \param Weights.
+LLVM_ABI SmallVector<uint32_t>
+downscaleWeights(ArrayRef<uint64_t> Weights,
+ std::optional<uint64_t> KnownMaxCount = std::nullopt);
+
+/// Calculate what to divide by to scale counts.
+///
+/// Given the maximum count, calculate a divisor that will scale all the
+/// weights to strictly less than std::numeric_limits<uint32_t>::max().
+LLVM_ABI inline uint64_t calculateCountScale(uint64_t MaxCount) {
+ return MaxCount < std::numeric_limits<uint32_t>::max()
+ ? 1
+ : MaxCount / std::numeric_limits<uint32_t>::max() + 1;
+}
+
+/// Scale an individual branch count.
+///
+/// Scale a 64-bit weight down to 32-bits using \c Scale.
+///
+LLVM_ABI inline uint32_t scaleBranchCount(uint64_t Count, uint64_t Scale) {
+ uint64_t Scaled = Count / Scale;
+ assert(Scaled <= std::numeric_limits<uint32_t>::max() && "overflow 32-bits");
+ return Scaled;
+}
+
/// Specify that the branch weights for this terminator cannot be known at
/// compile time. This should only be called by passes, and never as a default
/// behavior in e.g. MDBuilder. The goal is to use this info to validate passes
diff --git a/llvm/include/llvm/Transforms/Utils/Instrumentation.h b/llvm/include/llvm/Transforms/Utils/Instrumentation.h
index 962d9e734a40a..93ab8c693607f 100644
--- a/llvm/include/llvm/Transforms/Utils/Instrumentation.h
+++ b/llvm/include/llvm/Transforms/Utils/Instrumentation.h
@@ -169,26 +169,6 @@ struct SanitizerCoverageOptions {
SanitizerCoverageOptions() = default;
};
-/// Calculate what to divide by to scale counts.
-///
-/// Given the maximum count, calculate a divisor that will scale all the
-/// weights to strictly less than std::numeric_limits<uint32_t>::max().
-static inline uint64_t calculateCountScale(uint64_t MaxCount) {
- return MaxCount < std::numeric_limits<uint32_t>::max()
- ? 1
- : MaxCount / std::numeric_limits<uint32_t>::max() + 1;
-}
-
-/// Scale an individual branch count.
-///
-/// Scale a 64-bit weight down to 32-bits using \c Scale.
-///
-static inline uint32_t scaleBranchCount(uint64_t Count, uint64_t Scale) {
- uint64_t Scaled = Count / Scale;
- assert(Scaled <= std::numeric_limits<uint32_t>::max() && "overflow 32-bits");
- return Scaled;
-}
-
// Use to ensure the inserted instrumentation has a DebugLocation; if none is
// attached to the source instruction, try to use a DILocation with offset 0
// scoped to surrounding function (if it has a DebugLocation).
diff --git a/llvm/lib/IR/ProfDataUtils.cpp b/llvm/lib/IR/ProfDataUtils.cpp
index b1b5f67689e6d..489fbfef00e4d 100644
--- a/llvm/lib/IR/ProfDataUtils.cpp
+++ b/llvm/lib/IR/ProfDataUtils.cpp
@@ -270,6 +270,18 @@ void setBranchWeights(Instruction &I, ArrayRef<uint32_t> Weights,
I.setMetadata(LLVMContext::MD_prof, BranchWeights);
}
+SmallVector<uint32_t> downscaleWeights(ArrayRef<uint64_t> Weights,
+ std::optional<uint64_t> KnownMaxCount) {
+ uint64_t MaxCount = KnownMaxCount.has_value() ? KnownMaxCount.value()
+ : *llvm::max_element(Weights);
+ assert(MaxCount > 0 && "Bad max count");
+ uint64_t Scale = calculateCountScale(MaxCount);
+ SmallVector<unsigned, 4> DownscaledWeights;
+ for (const auto &ECI : Weights)
+ DownscaledWeights.push_back(scaleBranchCount(ECI, Scale));
+ return DownscaledWeights;
+}
+
void scaleProfData(Instruction &I, uint64_t S, uint64_t T) {
assert(T != 0 && "Caller should guarantee");
auto *ProfileData = I.getMetadata(LLVMContext::MD_prof);
diff --git a/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp b/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp
index e0b22ef94d064..374f64069f4e0 100644
--- a/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp
+++ b/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp
@@ -2409,11 +2409,7 @@ static std::string getSimpleNodeName(const BasicBlock *Node) {
void llvm::setProfMetadata(Instruction *TI, ArrayRef<uint64_t> EdgeCounts,
uint64_t MaxCount) {
- assert(MaxCount > 0 && "Bad max count");
- uint64_t Scale = calculateCountScale(MaxCount);
- SmallVector<unsigned, 4> Weights;
- for (const auto &ECI : EdgeCounts)
- Weights.push_back(scaleBranchCount(ECI, Scale));
+ auto Weights = downscaleWeights(EdgeCounts, MaxCount);
LLVM_DEBUG(dbgs() << "Weight is: "; for (const auto &W
: Weights) {
@@ -2434,7 +2430,7 @@ void llvm::setProfMetadata(Instruction *TI, ArrayRef<uint64_t> EdgeCounts,
uint64_t TotalCount =
std::accumulate(EdgeCounts.begin(), EdgeCounts.end(), (uint64_t)0,
[](uint64_t c1, uint64_t c2) { return c1 + c2; });
- Scale = calculateCountScale(WSum);
+ auto Scale = calculateCountScale(WSum);
BranchProbability BP(scaleBranchCount(Weights[0], Scale),
scaleBranchCount(WSum, Scale));
std::string BranchProbStr;
|
llvm/include/llvm/IR/ProfDataUtils.h
Outdated
/// | ||
/// Given the maximum count, calculate a divisor that will scale all the | ||
/// weights to strictly less than std::numeric_limits<uint32_t>::max(). | ||
LLVM_ABI inline uint64_t calculateCountScale(uint64_t MaxCount) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What's the motivation to annotate LLVM_ABI
for the helper functions here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
They are also used in the Instrumentation
component, but... hmm. I suppose since they are defined in the header they don't need it.
5b1cbec
to
8c2429e
Compare
8c2429e
to
859616d
Compare
…on` into `ProfileData`
859616d
to
5ffa9b3
Compare
LLVM Buildbot has detected a new failure on builder Full details are available at: https://lab.llvm.org/buildbot/#/builders/10/builds/11471 Here is the relevant piece of the build log for the reference
|
LLVM Buildbot has detected a new failure on builder Full details are available at: https://lab.llvm.org/buildbot/#/builders/88/builds/15080 Here is the relevant piece of the build log for the reference
|
LLVM Buildbot has detected a new failure on builder Full details are available at: https://lab.llvm.org/buildbot/#/builders/163/builds/24712 Here is the relevant piece of the build log for the reference
|
LLVM Buildbot has detected a new failure on builder Full details are available at: https://lab.llvm.org/buildbot/#/builders/203/builds/19921 Here is the relevant piece of the build log for the reference
|
LLVM Buildbot has detected a new failure on builder Full details are available at: https://lab.llvm.org/buildbot/#/builders/205/builds/18710 Here is the relevant piece of the build log for the reference
|
LLVM Buildbot has detected a new failure on builder Full details are available at: https://lab.llvm.org/buildbot/#/builders/175/builds/23450 Here is the relevant piece of the build log for the reference
|
LLVM Buildbot has detected a new failure on builder Full details are available at: https://lab.llvm.org/buildbot/#/builders/137/builds/23610 Here is the relevant piece of the build log for the reference
|
LLVM Buildbot has detected a new failure on builder Full details are available at: https://lab.llvm.org/buildbot/#/builders/123/builds/25270 Here is the relevant piece of the build log for the reference
|
LLVM Buildbot has detected a new failure on builder Full details are available at: https://lab.llvm.org/buildbot/#/builders/185/builds/23447 Here is the relevant piece of the build log for the reference
|
LLVM Buildbot has detected a new failure on builder Full details are available at: https://lab.llvm.org/buildbot/#/builders/144/builds/32863 Here is the relevant piece of the build log for the reference
|
LLVM Buildbot has detected a new failure on builder Full details are available at: https://lab.llvm.org/buildbot/#/builders/8/builds/19599 Here is the relevant piece of the build log for the reference
|
LLVM Buildbot has detected a new failure on builder Full details are available at: https://lab.llvm.org/buildbot/#/builders/117/builds/12529 Here is the relevant piece of the build log for the reference
|
LLVM Buildbot has detected a new failure on builder Full details are available at: https://lab.llvm.org/buildbot/#/builders/206/builds/4809 Here is the relevant piece of the build log for the reference
|
LLVM Buildbot has detected a new failure on builder Full details are available at: https://lab.llvm.org/buildbot/#/builders/158/builds/10638 Here is the relevant piece of the build log for the reference
|
LLVM Buildbot has detected a new failure on builder Full details are available at: https://lab.llvm.org/buildbot/#/builders/105/builds/11073 Here is the relevant piece of the build log for the reference
|
The logic isn’t instrumentation-specific, and the refactoring allows users avoid a dependency on
Instrumentation
and just take one onProfileData
(which a fairly low-level dependency)